{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1 align=\"center\">Colocalization and Distance Measurements of Objects in Fluorescence Microscopy</h1>\n",
    "\n",
    "**Summary:**\n",
    "\n",
    "1. SimpleITK provides a large number of filters that enable segmentation of objects and quantification of their characteristics and of the spatial relationships between them.\n",
    "\n",
    "This notebook illustrates the construction of a SimpleITK-based analysis workflow in which we quantify the colocalization of two markers (FITC and Cy3) and the distance from the protein blobs these markers define to the nucleus blob(s) defined by a third marker (DAPI).\n",
    "\n",
    "The image we work with was obtained by 3D Structured Illumination Microscopy (3D-SIM) and is provided courtesy of the Etienne Leygue lab at CancerCare Manitoba and The Genomic Centre for Cancer Research and Diagnosis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import SimpleITK as sitk\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "%matplotlib notebook\n",
    "import gui\n",
    "\n",
    "%run update_path_to_download_script\n",
    "from downloaddata import fetch_data as fdata\n",
    "\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "# Always write output to a separate directory; we don't want to pollute the source directory.\n",
    "import os\n",
    "OUTPUT_DIR = 'Output'\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data\n",
    "Load the 3D multi-channel structured illumination microscopy image, split it into separate channels, and display them."
   ]
  },
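  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The channel splitting in the next cell relies on two conventions: each channel is one component of a vector pixel, and its name is stored in the image's meta-data dictionary under the key `channel_i_name`. Below is a minimal sketch of that mechanism on a small synthetic image. The channel names, sizes, and intensities are made up for illustration, and `sitk.Compose` is used only to build the test image.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "# Hypothetical three-channel image: each channel is a constant 4x5x6 (zyx) volume.\n",
    "names = ['DAPI', 'FITC', 'Cy3']\n",
    "components = [sitk.GetImageFromArray(np.full((4, 5, 6), 100 * (i + 1), dtype=np.uint16))\n",
    "              for i in range(len(names))]\n",
    "multi_channel = sitk.Compose(components)\n",
    "\n",
    "# Store the channel names using the same 'channel_i_name' key convention.\n",
    "for i, name in enumerate(names):\n",
    "    multi_channel.SetMetaData('channel_{0}_name'.format(i), name)\n",
    "\n",
    "# Extract each vector component as a scalar image, keyed by its channel name.\n",
    "channels = {multi_channel.GetMetaData('channel_{0}_name'.format(i)): sitk.VectorIndexSelectionCast(multi_channel, i)\n",
    "            for i in range(multi_channel.GetNumberOfComponentsPerPixel())}\n",
    "\n",
    "print(channels['FITC'].GetSize())                            # (6, 5, 4), SimpleITK uses xyz order\n",
    "print(sitk.GetArrayViewFromImage(channels['FITC'])[0, 0, 0])  # 200\n",
    "```"
   ]
  },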
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "sim_image = sitk.ReadImage(fdata('microscopy_colocalization.nrrd'))\n",
    "\n",
    "# Our original image was saved using mm as the unit, but in microscopy the more common unit is um.\n",
    "# We rescale the spacing and origin so that they are in um. If your image is already in um, skip this modification.\n",
    "sim_image.SetSpacing([spc*1000 for spc in sim_image.GetSpacing()])\n",
    "sim_image.SetOrigin([org*1000 for org in sim_image.GetOrigin()])\n",
    "\n",
    "# The channel name appears in the image's meta-data dictionary with key: channel_i_name\n",
    "channel_titles = [sim_image.GetMetaData('channel_{0}_name'.format(i)) for i in range(sim_image.GetNumberOfComponentsPerPixel())]\n",
    "channels = [sitk.VectorIndexSelectionCast(sim_image,i) \n",
    "            for i in range(sim_image.GetNumberOfComponentsPerPixel())]\n",
    "fitc_image = channels[channel_titles.index('FITC')]\n",
    "cy3_image = channels[channel_titles.index('Cy3')]\n",
    "dapi_image = channels[channel_titles.index('DAPI')]\n",
    "\n",
    "gui.MultiImageDisplay(image_list=channels, title_list=channel_titles, shared_slider=True, \n",
    "                      intensity_slider_range_percentile=[0,100],\n",
    "                      figure_size=(20,10));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Segmenting Channels\n",
    "\n",
    "Based on the visualization above, we perform a manual segmentation, visually estimating a threshold for each of the three channels to isolate our objects of interest.\n",
    "\n",
    "As we know that the image contains a single nucleus, we take the largest connected component in the thresholded DAPI channel as the nucleus. When there are multiple nuclei, more sophisticated methods may be needed."
   ]
  },
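  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The largest-connected-component strategy described above can be sketched on a small synthetic image. This is an illustration under stated assumptions, not the notebook's code: the helper name `largest_component`, the 2D test image, and the threshold value are hypothetical. The sketch relies on `sitk.RelabelComponent` sorting labels by object size, so that label 1 is the largest component.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "def largest_component(image, threshold, min_size=0):\n",
    "    '''Threshold image and keep only its largest connected component (hypothetical helper).'''\n",
    "    binary = image > threshold\n",
    "    labeled = sitk.ConnectedComponent(binary)\n",
    "    # With sortByObjectSize=True, label 1 is the largest remaining component.\n",
    "    relabeled = sitk.RelabelComponent(labeled, minimumObjectSize=min_size, sortByObjectSize=True)\n",
    "    return relabeled == 1\n",
    "\n",
    "# Synthetic 2D image: a 10x10 bright square and a 3x3 bright speck.\n",
    "arr = np.zeros((40, 40), dtype=np.uint16)\n",
    "arr[5:15, 5:15] = 20000\n",
    "arr[30:33, 30:33] = 20000\n",
    "mask = largest_component(sitk.GetImageFromArray(arr), 12000)\n",
    "\n",
    "mask_arr = sitk.GetArrayViewFromImage(mask)\n",
    "print(int(mask_arr.sum()))  # 100, only the square survives\n",
    "print(mask_arr[31, 31])     # 0, the speck was discarded\n",
    "```"
   ]
  },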
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fitc_threshold = 10000\n",
    "cy3_threshold = 10000\n",
    "dapi_threshold = 12000\n",
    "\n",
    "# Segment the nucleus marker image and then label each connected component. \n",
    "dapi_binary_segmentation = dapi_image > dapi_threshold\n",
    "\n",
    "dapi_labeled_segmentation = sitk.ConnectedComponent(dapi_binary_segmentation)\n",
    "# Our simple thresholding creates many small connected components. We have a \n",
    "# single nucleus so we relabel and take the largest connected component.\n",
    "# Minimal object size is given as number of voxels.\n",
    "dapi_labeled_segmentation = sitk.RelabelComponent(dapi_labeled_segmentation, minimumObjectSize=1000, sortByObjectSize=True)\n",
    "dapi_binary_segmentation = dapi_labeled_segmentation == 1\n",
    "\n",
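    "# Shape statistics (e.g. centroid) of the nucleus, computed on the binary segmentation so that only label 1 is used.\n",
    "dapi_stats_filter = sitk.LabelShapeStatisticsImageFilter()\n",
    "dapi_stats_filter.Execute(dapi_binary_segmentation)\n",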
    "\n",
    "# Segment the protein marker images and label each connected component.\n",
    "markers_binary_segmentations = [fitc_image > fitc_threshold,\n",
    "                                cy3_image > cy3_threshold]\n",
    "markers_labeled_segmentations = [sitk.ConnectedComponent(marker_binary_segmentation) \\\n",
    "                                 for marker_binary_segmentation in markers_binary_segmentations]\n",
    "gui.MultiImageDisplay([dapi_binary_segmentation]);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Colocalization\n",
    "\n",
    "The locations of the specific protein we are interested in correspond to locations where the markers overlap, in our case FITC and Cy3.\n",
    "\n",
    "In the next cell we compute various size characteristics of the colocalized markers: size in voxels, volume, and the percentage of each original marker blob that is covered by the colocalization."
   ]
  },
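  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The per-blob bookkeeping in the next cell can be illustrated with plain NumPy on 1D masks. The bookkeeping consists of three steps: forming the colocalization as the overlap of the thresholded markers, looking up the marker labels at the first voxel of each colocalization blob, and computing the covered percentage. This is a simplified sketch: the masks are made up, and the hypothetical helper `label_runs` stands in for `sitk.ConnectedComponent` in 1D. The first-voxel lookup is sufficient because a connected colocalization blob lies entirely inside one connected blob of each marker.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def label_runs(mask):\n",
    "    '''Label runs of True values in a 1D boolean array: 0 is background, runs get 1..n.'''\n",
    "    # A run starts where mask is True and the previous element is False.\n",
    "    starts = mask & ~np.concatenate(([False], mask[:-1]))\n",
    "    return np.where(mask, np.cumsum(starts), 0)\n",
    "\n",
    "fitc = np.array([0, 1, 1, 1, 1, 0, 0, 1, 1, 0], dtype=bool)\n",
    "cy3 = np.array([0, 0, 1, 1, 1, 1, 0, 0, 1, 0], dtype=bool)\n",
    "coloc = fitc & cy3  # same as multiplying the 0/1 marker images\n",
    "\n",
    "fitc_labels, cy3_labels, coloc_labels = label_runs(fitc), label_runs(cy3), label_runs(coloc)\n",
    "\n",
    "rows = []\n",
    "for coloc_label in range(1, int(coloc_labels.max()) + 1):\n",
    "    blob = coloc_labels == coloc_label\n",
    "    first = int(np.argmax(blob))  # first voxel of this colocalization blob\n",
    "    row = [coloc_label, int(blob.sum())]\n",
    "    for marker_labels in (fitc_labels, cy3_labels):\n",
    "        marker_label = int(marker_labels[first])\n",
    "        marker_size = int((marker_labels == marker_label).sum())\n",
    "        row += [marker_label, float(100.0 * blob.sum() / marker_size)]\n",
    "    rows.append(row)\n",
    "\n",
    "print(rows)  # [[1, 3, 1, 75.0, 1, 75.0], [2, 1, 2, 50.0, 2, 100.0]]\n",
    "```"
   ]
  },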
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from numba import njit\n",
    "\n",
    "@njit\n",
    "def first_index(arr, val):\n",
    "    '''\n",
    "    Find the index of the first appearance of val in arr. The loop is extremely slow in Python, but we use\n",
    "    the numba just-in-time compiler to optimize it (@njit decorator). Another option is to use np.where,\n",
    "    which returns all indices, but that is significantly (six times) slower than the numba optimized function.\n",
    "    '''\n",
    "    for index, value in np.ndenumerate(arr):\n",
    "        if val == value:\n",
    "            return index\n",
    "\n",
    "markers_names = ['FITC', 'Cy3']\n",
    "\n",
    "# Compute the label statistics for all protein markers. \n",
    "def apply_label_shape_filter(label_image):\n",
    "    filt=sitk.LabelShapeStatisticsImageFilter()\n",
    "    filt.Execute(label_image)\n",
    "    return filt\n",
    "markers_stats_filters = [apply_label_shape_filter(labeled_seg) for labeled_seg in markers_labeled_segmentations]\n",
    "\n",
    "# Compute the colocalization image and label each connected component.\n",
    "colocalization_binary_segmentation = markers_binary_segmentations[0]\n",
    "for binary_segmentation in markers_binary_segmentations[1:]:\n",
    "    colocalization_binary_segmentation = colocalization_binary_segmentation * binary_segmentation\n",
    "colocolization_labeled_segmentation = sitk.ConnectedComponent(colocalization_binary_segmentation)    \n",
    "\n",
    "# Compute the label statistics for the colocalized protein markers.\n",
    "colocolization_stats_filter = sitk.LabelShapeStatisticsImageFilter()\n",
    "colocolization_stats_filter.Execute(colocolization_labeled_segmentation)\n",
    "\n",
    "# Create a dictionary that maps between the colocalization labels and the individual labels from\n",
    "# each marker.\n",
    "colocalization_labels_2_original_labels = {}\n",
    "colocolization_labeled_segmentation_arr_view = sitk.GetArrayViewFromImage(colocolization_labeled_segmentation)\n",
    "for label in colocolization_stats_filter.GetLabels():\n",
    "    # The index into the numpy array needs to be flipped as the order in numpy is zyx and in SimpleITK xyz\n",
    "    index = first_index(colocolization_labeled_segmentation_arr_view, label)[::-1]\n",
    "    colocalization_labels_2_original_labels[label] = [labeled_seg[index] for labeled_seg in markers_labeled_segmentations]\n",
    "    \n",
    "# Compute statistics for the colocalizations. Work with a list of lists and then \n",
    "# combine into a dataframe, faster than appending to the dataframe one by one.\n",
    "column_titles = ['colocalization size']*2 + [item for sublist in [[marker]*4 for marker in markers_names] for item in sublist]\n",
    "all_colocalizations_data = []\n",
    "for coloc_label, marker_labels_list in colocalization_labels_2_original_labels.items():\n",
    "    coloc_size = colocolization_stats_filter.GetPhysicalSize(coloc_label)\n",
    "    current_colocalization = [coloc_size, colocolization_stats_filter.GetNumberOfPixels(coloc_label)]\n",
    "    for label, filt in zip(marker_labels_list, markers_stats_filters):\n",
    "        marker_size = filt.GetPhysicalSize(label)\n",
    "        # Percentage of the original marker blob covered by this colocalization.\n",
    "        current_colocalization += [label, marker_size, filt.GetNumberOfPixels(label), 100.0*coloc_size/marker_size]\n",
    "    all_colocalizations_data.append(current_colocalization)\n",
    "    \n",
    "colocalization_information_df = \\\n",
    "    pd.DataFrame(all_colocalizations_data, columns=column_titles)\n",
    "marker_columns = ['label', 'size [um^3]', 'size [voxels]', 'colocalization percentage']\n",
    "colocalization_information_df.columns = pd.MultiIndex.from_tuples(zip(colocalization_information_df.columns,\n",
    "                                          ['um^3', 'voxels'] + marker_columns*len(markers_names)))\n",
    "# Save the colocalization results\n",
    "colocalization_information_df.to_csv(os.path.join(OUTPUT_DIR,'colocalization.csv'), index=False)\n",
    "\n",
    "# Display the first N rows as HTML\n",
    "head_length=20\n",
    "display(HTML(colocalization_information_df.head(head_length).to_html(index=False)))"
   ]
  },
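  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If numba is not installed, the first-occurrence search used above can be written with plain NumPy. This is a hypothetical alternative (`first_index_numpy` is not part of the notebook). Like the `np.where` option mentioned in the docstring, it examines the whole array, so we would expect it to be slower than the numba version on large images. It returns the index in numpy (zyx) order, so the same `[::-1]` flip to SimpleITK's xyz order applies.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def first_index_numpy(arr, val):\n",
    "    '''Index (numpy zyx order) of the first occurrence of val in arr, or None if absent.'''\n",
    "    # flatnonzero returns flat indices in C (row-major) order, the order np.ndenumerate visits.\n",
    "    flat = np.flatnonzero(arr == val)\n",
    "    if flat.size == 0:\n",
    "        return None\n",
    "    return tuple(int(i) for i in np.unravel_index(flat[0], arr.shape))\n",
    "\n",
    "labels = np.array([[[0, 0, 3], [3, 0, 5]]])  # shape (1, 2, 3), i.e. z, y, x\n",
    "print(first_index_numpy(labels, 3))        # (0, 0, 2)\n",
    "print(first_index_numpy(labels, 3)[::-1])  # (2, 0, 0), xyz order for SimpleITK\n",
    "print(first_index_numpy(labels, 7))        # None\n",
    "```"
   ]
  },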
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Distance of colocalizations from nuclei\n",
    "\n",
    "We now compute the distances from the marker colocalization blobs to the nuclei blobs, both edge (boundary) to edge distance and center to center distance. The centroid of a blob is defined as the location in physical space of the mean of the blob voxel locations. Consequently, if the blob is not convex, the centroid may lie outside the blob.\n",
    "\n",
    "Edge to edge distances are computed using a distance map; this treats all nuclei as a single object. Because we use the absolute value of the signed distance map, a blob located inside a nucleus is assigned its distance to the nucleus boundary."
   ]
  },
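  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The remark that the centroid of a non-convex blob may lie outside the blob is easy to check with NumPy. In this minimal sketch the ring-shaped blob is synthetic and the centroid is computed in index space. With SimpleITK, `LabelShapeStatisticsImageFilter.GetCentroid` reports the corresponding location in physical space.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Synthetic 2D ring (annulus) centered on pixel (10, 10): a non-convex blob.\n",
    "yy, xx = np.mgrid[0:21, 0:21]\n",
    "radius = np.hypot(xx - 10, yy - 10)\n",
    "blob = (radius >= 5) & (radius <= 8)\n",
    "\n",
    "# Centroid = mean of the blob's voxel locations (reported here in x, y order).\n",
    "ys, xs = np.nonzero(blob)\n",
    "centroid = (float(xs.mean()), float(ys.mean()))\n",
    "\n",
    "print(centroid)      # (10.0, 10.0)\n",
    "print(blob[10, 10])  # False, the centroid falls in the hole of the ring\n",
    "```"
   ]
  },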
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Compute the edge to edge distances using the absolute distance map from the nuclei segmentation.\n",
    "distance_map_from_all_nuclei = sitk.Abs(sitk.SignedMaurerDistanceMap(dapi_binary_segmentation, squaredDistance=False, useImageSpacing=True))\n",
    "\n",
    "distance_stats_filter = sitk.LabelIntensityStatisticsImageFilter()\n",
    "distance_stats_filter.Execute(colocolization_labeled_segmentation, distance_map_from_all_nuclei)\n",
    "\n",
    "labels_edge_distances = []\n",
    "for label in distance_stats_filter.GetLabels():\n",
    "    # Using minimum for each label gives us edge to edge distance\n",
    "    labels_edge_distances.append(colocalization_labels_2_original_labels[label] + [distance_stats_filter.GetMinimum(label)]) \n",
    "    \n",
    "ee_distances_df = pd.DataFrame(labels_edge_distances, columns = markers_names + ['edge edge distance to DAPI [um]'])\n",
    "ee_distances_df.to_csv(os.path.join(OUTPUT_DIR,'colocalization_edge_edge_distances.csv'), index=False)\n",
    "\n",
    "# Display the first N rows as HTML, sorted according to distance from nuclei\n",
    "head_length=20\n",
    "display(HTML(ee_distances_df.sort_values(by='edge edge distance to DAPI [um]').head(head_length).to_html(index=False)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Trust but Verify\n",
    "\n",
    "Visually check that the colocalizations with zero distance indeed agree with our computations using the distance map.\n",
    "\n",
    "Scroll through the image stack and then zoom in on a region with \"light\" pixels using the zoom tool (box/rubber band menu item). Zoom in on both images and hover with the mouse over the light pixels. On the bottom right you will see the pixel location and value. The pixel value is the label of that pixel. The combination of labels should match the table that appears above this cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "zero_distance_mask = sitk.Cast(colocalization_binary_segmentation*dapi_binary_segmentation, markers_labeled_segmentations[0].GetPixelIDValue())\n",
    "\n",
    "FITC_labels_zero_distance = zero_distance_mask*markers_labeled_segmentations[0]\n",
    "Cy3_labels_zero_distance = zero_distance_mask*markers_labeled_segmentations[1]\n",
    "\n",
    "gui.MultiImageDisplay(image_list=[FITC_labels_zero_distance, Cy3_labels_zero_distance], \n",
    "                      title_list=['FITC labels', 'Cy3 labels'], shared_slider=True, \n",
    "                      intensity_slider_range_percentile=[0,100],\n",
    "                      figure_size=(10,5));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Distance of colocalizations from nuclei centers\n",
    "\n",
    "Here we determine how close the proteins (or genes, if that is your interest) are to the center of the nucleus. Note that this distance is not symmetric: the distance from a protein to its closest nucleus is not necessarily the same as the distance from that nucleus to its closest protein."
   ]
  },
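  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The nearest-centroid computation and the asymmetry noted above can be illustrated on made-up 3D points. This sketch computes the pairwise distances with NumPy broadcasting rather than with the expanded dot-product form used in the next cell. The broadcasting form never takes the square root of a negative round-off value, but it builds an intermediate array of size (number of blobs) x (number of nuclei) x 3. The helper name `nearest` and the coordinates are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def nearest(sources, targets):\n",
    "    '''For each source point, return the index of the closest target point and the distance to it.'''\n",
    "    # d[i, j] is the Euclidean distance between sources[i] and targets[j].\n",
    "    d = np.linalg.norm(sources[:, np.newaxis, :] - targets[np.newaxis, :, :], axis=2)\n",
    "    idx = np.argmin(d, axis=1)\n",
    "    return idx, d[np.arange(len(sources)), idx]\n",
    "\n",
    "# Hypothetical centroids (um), all on the x axis for readability.\n",
    "proteins = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [10.0, 0.0, 0.0]])\n",
    "nuclei = np.array([[4.0, 0.0, 0.0], [20.0, 0.0, 0.0]])\n",
    "\n",
    "p_idx, p_dist = nearest(proteins, nuclei)  # every protein is closest to nucleus 0 (distances 4, 3, 6)\n",
    "n_idx, n_dist = nearest(nuclei, proteins)  # nucleus 0 is closest to protein 1, nucleus 1 to protein 2\n",
    "# Asymmetry: protein 0's nearest nucleus is nucleus 0, but nucleus 0's nearest protein is protein 1.\n",
    "print(p_idx, p_dist)\n",
    "print(n_idx, n_dist)\n",
    "```"
   ]
  },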
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "colocalization_labels, colocalization_centroids = zip(*[(coloc_label, colocolization_stats_filter.GetCentroid(coloc_label)) for coloc_label in colocolization_stats_filter.GetLabels()])\n",
    "colocalization_centroids = np.array(colocalization_centroids)\n",
    "\n",
    "nuclei_labels, nuclei_centroids = zip(*[(nucleus, dapi_stats_filter.GetCentroid(nucleus)) for nucleus in dapi_stats_filter.GetLabels()])\n",
    "nuclei_centroids = np.array(nuclei_centroids)\n",
    "\n",
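    "# Pairwise centroid distances using |a-b|^2 = |a|^2 + |b|^2 - 2a.b, then the closest nucleus for each colocalization.\n",
    "all_distances = -2 * np.dot(colocalization_centroids, nuclei_centroids.T)\n",
    "all_distances += np.sum(colocalization_centroids**2, axis=1)[:, np.newaxis]\n",
    "all_distances += np.sum(nuclei_centroids**2, axis=1)\n",
    "# Clamp tiny negative values caused by floating point round-off before taking the square root.\n",
    "all_distances = np.sqrt(np.maximum(all_distances, 0))\n",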
    "\n",
    "min_indexes = np.argmin(all_distances, axis=1)\n",
    "\n",
    "results = list(zip(colocalization_labels, tuple(np.array(nuclei_labels)[min_indexes]), all_distances[np.arange(len(min_indexes)), min_indexes]))\n",
    "# Replace the colocalization labels with the original channel labels\n",
    "results = [colocalization_labels_2_original_labels[coloc_label] + list((nucleus_label, distance)) for coloc_label, nucleus_label, distance in results]\n",
    "\n",
    "cc_distances_df = pd.DataFrame(results, columns = markers_names + ['DAPI', 'centroid centroid distance [um]'])\n",
    "cc_distances_df.to_csv(os.path.join(OUTPUT_DIR,'colocalization_centroid_centroid_distances.csv'), index=False)\n",
    "\n",
    "# Display the first N rows as HTML\n",
    "head_length=20\n",
    "display(HTML(cc_distances_df.head(head_length).to_html(index=False)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
